// Forward pass of the binary cross-entropy loss.
//
// Each thread handles the element at index threadIdx.x + blockIdx.x * blockDim.x
// and computes
//     -label * log(prob) - (1 - label) * log1p(-prob)
// for it. The per-thread terms are summed per block through shared memory and
// each block then adds its partial sum to *loss with a single atomicAdd.
//
// Parameters:
//   prob  - predicted probabilities, read at indices [0, dim)
//   label - target labels, read at indices [0, dim)
//   dim   - number of elements
//   loss  - accumulator that receives the sum of all per-element terms
//
// NOTE: the result is accumulated onto whatever value *loss already holds;
// this function never resets it.
//
// NOTE: loss is a float pointer regardless of whether the input is double or
// float, because the result is accumulated with the float overload of
// atomicAdd (a double-precision atomicAdd requires compute capability 6.0+).
//
// NOTE: no clamping is applied to prob, so prob == 0 or prob == 1 produce
// inf/NaN terms (the clamped variants are kept below as comments).
//
// Precondition: blockDim.x <= THREADS_PER_BLOCK_X, since that is the size of
// the shared scratch buffer indexed by threadIdx.x.
template <typename T>
__device__ void binary_cross_entropy_loss_forward(T *prob, T *label, int dim, float *loss) {
  __shared__ T local_loss[THREADS_PER_BLOCK_X];

  int idx = threadIdx.x + blockIdx.x * blockDim.x;

  if (idx >= dim) {
    // Past the end of the input: contribute zero so the block-level sum below
    // can add every slot unconditionally.
    local_loss[threadIdx.x] = 0;
  } else {
    T the_prob = prob[idx];
    T the_label = label[idx];
    //T log_prob = -log(max(the_prob, static_cast<T>(LOG_THRESHOLD))) * the_label;
    T log_prob = -log(the_prob) * the_label;
    //log_prob -= log1p(max(-the_prob, static_cast<T>(LOG_THRESHOLD) - 1)) * (1-the_label);
    log_prob -= log1p(-the_prob) * (1-the_label);

    local_loss[threadIdx.x] = log_prob;
  }

  // Every thread reaches this barrier (it is outside the branch above), so all
  // slots written by this block are visible to thread 0 afterwards.
  __syncthreads();
  if (0 == threadIdx.x) {
    // Thread 0 does the in-block accumulation. Only the first blockDim.x slots
    // were written by this block; summing up to THREADS_PER_BLOCK_X would read
    // uninitialized shared memory whenever the block is launched with fewer
    // threads than the buffer size.
    T total_local_loss = 0;
    for (unsigned int i = 0; i < blockDim.x; ++i)
      total_local_loss += local_loss[i];
    // One atomic per block keeps contention on *loss low.
    atomicAdd(loss, static_cast<float>(total_local_loss));
  }
}

// Backward pass of the binary cross-entropy loss.
//
// Each thread handles the element at index threadIdx.x + blockIdx.x * blockDim.x
// and writes, scaled by weight, the derivative of
//     -label * log(pred) - (1 - label) * log1p(-pred)
// with respect to the prediction and/or the label.
//
// Parameters:
//   prob           - predicted probabilities, read at indices [0, dim)
//   label          - target labels, read at indices [0, dim)
//   dim            - number of elements
//   gradient_pred  - output for d(loss)/d(pred); skipped when null
//   gradient_label - output for d(loss)/d(label); skipped when null
//   weight         - scalar multiplied into every gradient entry
//
// Only indices in [0, dim) are read or written. No clamping is applied, so
// pred == 0 or pred == 1 produce inf/NaN gradients.
template <typename T>
__device__ void binary_cross_entropy_loss_backward(T *prob, T *label, int dim, T *gradient_pred, T *gradient_label, T weight) {

  int idx = threadIdx.x + blockIdx.x * blockDim.x;

  if (idx >= dim) {
    // Threads past the end of the data own no element. They must not touch the
    // gradient buffers: storing a zero at gradient[idx] here would be a write
    // beyond the dim elements this kernel is responsible for.
    return;
  }

  T the_pred = prob[idx];
  T the_label = label[idx];

  if (gradient_pred) {
    // d/dpred = -(label / pred - (1 - label) / (1 - pred))
    gradient_pred[idx] = -weight*(the_label/the_pred - (1-the_label)/(1-the_pred));
  }
  if (gradient_label) {
    // d/dlabel = -log(pred) + log(1 - pred) = -log(pred / (1 - pred))
    gradient_label[idx] = -weight*log(the_pred/(1-the_pred));
  }
}


// Kernel entry points. They are declared extern "C" so the symbols keep their
// unmangled names; each one simply forwards its arguments to the matching
// template instantiation above.
extern "C" {
  // Forward pass, single-precision inputs. The result is added to *loss.
  __global__ void binary_cross_entropy_loss_forward_float(float *prob, float *label, int dim, float *loss) {
    binary_cross_entropy_loss_forward<float>(prob, label, dim, loss);
  }

  // Forward pass, double-precision inputs. The accumulator is still a float.
  __global__ void binary_cross_entropy_loss_forward_double(double *prob, double *label, int dim, float *loss) {
    binary_cross_entropy_loss_forward<double>(prob, label, dim, loss);
  }

  // Backward pass, single precision. Either gradient pointer may be null, in
  // which case that gradient is not written.
  __global__ void binary_cross_entropy_loss_backward_float(float *prob, float *label, int dim, float *gradient_pred, float *gradient_label, float weight) {
    binary_cross_entropy_loss_backward<float>(prob, label, dim, gradient_pred, gradient_label, weight);
  }

  // Backward pass, double precision. Either gradient pointer may be null, in
  // which case that gradient is not written.
  __global__ void binary_cross_entropy_loss_backward_double(double *prob, double *label, int dim, double *gradient_pred, double *gradient_label, double weight) {
    binary_cross_entropy_loss_backward<double>(prob, label, dim, gradient_pred, gradient_label, weight);
  }
}

// vim: ft=cuda
